{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from esm import pretrained\n",
    "import pysam\n",
    "from easydict import EasyDict as edict\n",
    "from CLEAN.utils import *\n",
    "from CLEAN.evaluate import *\n",
    "from CLEAN.model import LayerNormNet\n",
    "from CLEAN.distance_map import *\n",
    "import pandas as pd \n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"train_data\": \"split100\",\n",
    "    \"inference_fasta_folder\": \"data\",\n",
    "    \"inference_fasta\": \"new.fasta\",\n",
    "    \"gpu_id\": 0, # None if use cpu\n",
    "    \"inference_fasta_start\": 0,\n",
    "    \"inference_fasta_end\": 300,\n",
    "    \"toks_per_batch\": 2048,\n",
    "    \"esm_type\": \"esm1b_t33_650M_UR50S\",\n",
    "    \"truncation_seq_length\": 1022, \n",
    "    \"esm_batches_per_clean_inference\": 200,\n",
    "    \"gmm\": \"./data/pretrained/gmm_ensumble.pkl\"\n",
    "}\n",
    "args = edict(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class CustomFastaBatchedDataset(object):\n",
    "    def __init__(self, fasta_obj, fasta_start=0, fasta_end=None):\n",
    "        \"\"\"\n",
    "        Initialize the dataset from a pysam.FastaFile object.\n",
    "\n",
    "        Parameters:\n",
    "            fasta_obj (pysam.FastaFile): The pysam FASTA file object.\n",
    "            fasta_start (int, optional): Start index (inclusive) for slicing references. Defaults to 0.\n",
    "            fasta_end (int, optional): End index (exclusive) for slicing references.\n",
    "                                       If None, all sequences from fasta_start onward are used.\n",
    "        \"\"\"\n",
    "        # Get all references from the FASTA object.\n",
    "        refs = fasta_obj.references\n",
    "        if fasta_end is None:\n",
    "            fasta_end = len(refs)\n",
    "\n",
    "        self.sequence_indices_labels_dict = {\n",
    "            i:label for i, label in zip(\n",
    "                range(len(refs)), refs)\n",
    "        }\n",
    "        self.sequence_labels_indices_dict = {\n",
    "            label:i for i, label in zip(\n",
    "                range(len(refs)), refs)\n",
    "        }\n",
    "        # Slice the list of references.\n",
    "        subset_refs = refs[fasta_start:fasta_end]\n",
    "        # Populate the sequence labels and the sequence strings.\n",
    "        self.sequence_labels = list(subset_refs)\n",
    "        self.sequence_strs = [fasta_obj.fetch(ref) for ref in subset_refs]\n",
    "        \n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.sequence_labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.sequence_labels[idx], self.sequence_strs[idx]\n",
    "\n",
    "    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):\n",
    "        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]\n",
    "        sizes.sort()\n",
    "        batches = []\n",
    "        buf = []\n",
    "        max_len = 0\n",
    "\n",
    "        def _flush_current_buf():\n",
    "            nonlocal max_len, buf\n",
    "            if len(buf) == 0:\n",
    "                return\n",
    "            batches.append(buf)\n",
    "            buf = []\n",
    "            max_len = 0\n",
    "\n",
    "        for sz, i in sizes:\n",
    "            sz += extra_toks_per_seq\n",
    "            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:\n",
    "                _flush_current_buf()\n",
    "            max_len = max(max_len, sz)\n",
    "            buf.append(i)\n",
    "\n",
    "        _flush_current_buf()\n",
    "        return batches\n",
    "\n",
    "\n",
    "def get_last_layer_emb(args, model, toks, strs, repr_layers):\n",
    "    if torch.cuda.is_available() and not args.gpu_id:\n",
    "        toks = toks.to(device=\"cuda\", non_blocking=True)\n",
    "    with torch.no_grad():\n",
    "        out = model(toks, repr_layers=repr_layers, return_contacts=False)\n",
    "        representations = {\n",
    "            layer: t.to(device=\"cpu\") for layer, t in out[\"representations\"].items()\n",
    "        }\n",
    "        representations = representations[repr_layers[0]] # only use the last layer\n",
    "        truncate_lens = [min(args.truncation_seq_length, len(strs_i)) for strs_i in strs]\n",
    "        embeddings = [emb[1 : truncate_len + 1].mean(0).clone() for emb,truncate_len \n",
    "                        in zip(representations, truncate_lens)]\n",
    "    return embeddings\n",
    "\n",
    "def get_max_sep_predictions_dict(inference_df, gmm):\n",
    "    max_sep_predictions = {}\n",
    "    for sequence_label in inference_df.columns:\n",
    "        smallest_10_dist_df = inference_df[sequence_label].nsmallest(10)\n",
    "        dist_lst = list(smallest_10_dist_df)\n",
    "        max_sep_i = maximum_separation(dist_lst, True, False)\n",
    "        ec = []\n",
    "        for i in range(max_sep_i+1):\n",
    "            EC_i = smallest_10_dist_df.index[i]\n",
    "            dist_i = smallest_10_dist_df[i]\n",
    "            if gmm != None:\n",
    "                gmm_lst = pickle.load(open(gmm, 'rb'))\n",
    "                dist_i = infer_confidence_gmm(dist_i, gmm_lst)\n",
    "            dist_str = \"{:.4f}\".format(dist_i)\n",
    "            ec.append('EC:' + str(EC_i) + '/' + dist_str)\n",
    "        max_sep_predictions[sequence_label] = ec\n",
    "    return max_sep_predictions\n",
    "\n",
    "\n",
    "def CLEAN_max_sep_predictions(\n",
    "        args,  CLEAN_model, sequence_label_esm_emb_dict, emb_train, ec_id_dict_train, device):\n",
    "    esm_emb_inference = torch.cat(\n",
    "        [sequence_label_esm_emb_dict[label].unsqueeze(0) \n",
    "            for label in sequence_label_esm_emb_dict])\n",
    "    id_ec_inference_dummy = {seq_label:[] for seq_label in sequence_label_esm_emb_dict}\n",
    "    with torch.no_grad():\n",
    "        model_emb_inference  = CLEAN_model(esm_emb_inference.to(device)).to(\"cpu\").clone()\n",
    "    inference_dist = get_dist_map_test(\n",
    "        emb_train, model_emb_inference, ec_id_dict_train, id_ec_inference_dummy, \"cpu\", torch.float32)\n",
    "    inference_df = pd.DataFrame.from_dict(inference_dist)\n",
    "    max_sep_predictions_dict = get_max_sep_predictions_dict(inference_df, args.gmm)\n",
    "    return max_sep_predictions_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "inference_fasta_path = f'{args.inference_fasta_folder}/{args.inference_fasta}'\n",
    "inference_fasta = pysam.FastaFile(inference_fasta_path)\n",
    "dataset = CustomFastaBatchedDataset(\n",
    "    inference_fasta, fasta_start=args.inference_fasta_start, fasta_end=args.inference_fasta_end)\n",
    "# keep track of index label mappings in the original fasta\n",
    "sequence_indices_labels_dict = dataset.sequence_indices_labels_dict\n",
    "sequence_labels_indices_dict = dataset.sequence_labels_indices_dict\n",
    "batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load ESM model\n",
    "model, alphabet = pretrained.load_model_and_alphabet(args.esm_type)\n",
    "repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in [-1]] # only use the last layer\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length), batch_sampler=batches\n",
    ")\n",
    "if torch.cuda.is_available() and not args.gpu_id:\n",
    "    import os\n",
    "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu_id)\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## load CLEAN model\n",
    "if torch.cuda.is_available() and not args.gpu_id:\n",
    "    device = \"cuda\"\n",
    "else: \n",
    "    device = \"cpu\"\n",
    "CLEAN_model = LayerNormNet(512, 128, device, torch.float32)\n",
    "checkpoint = torch.load('./data/pretrained/'+ args.train_data +'.pth', map_location=device)\n",
    "CLEAN_model.load_state_dict(checkpoint)\n",
    "CLEAN_model.eval()\n",
    "id_ec_train, ec_id_dict_train = get_ec_id_dict('./data/' + args.train_data + '.csv')\n",
    "emb_train = torch.load('./data/pretrained/100.pt', map_location=\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "80it [00:28,  2.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([300, 128])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5242/5242 [00:02<00:00, 1932.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating eval distance map, between 300 test ids and 5242 train EC cluster centers\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:09, 32.92it/s]\n"
     ]
    }
   ],
   "source": [
    "sequence_label_esm_emb_dict = {}\n",
    "max_sep_predictions_dict = {} \n",
    "\n",
    "for batch_idx, (labels, strs, toks) in tqdm(enumerate(data_loader)):\n",
    "    embeddings = get_last_layer_emb(\n",
    "        args=args, model=model, toks=toks, strs=strs, repr_layers=repr_layers)\n",
    "    for label, emb in zip(labels, embeddings):\n",
    "        sequence_label_esm_emb_dict[label] = emb\n",
    "\n",
    "    if (batch_idx + 1) % args.esm_batches_per_clean_inference == 0:\n",
    "        # perform clean inference\n",
    "        predictions = CLEAN_max_sep_predictions(\n",
    "            args, CLEAN_model, sequence_label_esm_emb_dict, emb_train, ec_id_dict_train, device\n",
    "        )\n",
    "        max_sep_predictions_dict.update(predictions)\n",
    "        sequence_label_esm_emb_dict = {}\n",
    "        \n",
    "# process any remaining sequences that didn't complete a full batch group\n",
    "if sequence_label_esm_emb_dict:\n",
    "    predictions = CLEAN_max_sep_predictions(\n",
    "        args, CLEAN_model, sequence_label_esm_emb_dict, emb_train, ec_id_dict_train, device\n",
    "    )\n",
    "    max_sep_predictions_dict.update(predictions)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_sep_predictions_df = pd.DataFrame([\n",
    "    {'Index': sequence_labels_indices_dict[seq_id], 'Seq_ID': seq_id, 'Prediction': '; '.join(max_sep_predictions_dict[seq_id])}\n",
    "    for seq_id in max_sep_predictions_dict\n",
    "]).sort_values('Index').reset_index(drop=True)\n",
    "fasta_name = args.inference_fasta.split(\".\")[0]\n",
    "prediction_file_name = f\"results/inputs/{fasta_name}_{args.inference_fasta_start}_{args.inference_fasta_end}.csv\"\n",
    "max_sep_predictions_df.to_csv(prediction_file_name, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
